In this project, we will use transfer learning to predict the classes for a subset of images from the Caltech-101 dataset. We use transfer learning because the dataset is small (roughly 50-60 images per class), far too little data to train a deep network from scratch effectively.
First, we need to import some libraries:
library(readr)
library(ggplot2)
library(dplyr)
library(methods)
library(stringi)
library(keras)
library(glmnet)
Then, we import the Caltech-101 dataset.
input_dir <- "dataset"
image_paths <- dir(input_dir, recursive = TRUE)
ext <- stri_match(image_paths, regex = "\\.([A-Za-z]+$)")[,2]
image_paths <- image_paths[stri_trans_tolower(ext) %in% c("jpg", "png", "jpeg")]
class_vector <- dirname(image_paths)
class_names <- levels(factor(class_vector))
n <- length(class_vector)
Z <- array(0, dim = c(n, 224, 224, 3))
y <- as.numeric(factor(class_vector)) - 1L
for (i in seq_len(n))
{
pt <- file.path(input_dir, image_paths[i])
image <- image_to_array(image_load(pt, target_size = c(224,224)))
Z[i,,,] <- array_reshape(image, c(1, dim(image)))
}
# permute
set.seed(1)
index <- sample(seq_len(nrow(Z)))
Z <- Z[index,,,]
y <- y[index]
The dataset contains 20 classes, each having about 50-60 images. Let’s look at some examples of each class:
par(mar = c(0,0,0,0))
par(mfrow = c(4, 5))
set.seed(1)
for (i in 0:19) {
plot(0,0,xlim=c(0,1),ylim=c(0,1),axes= FALSE,type = "n")
j <- sample(which(y == i), 1)
rasterImage(Z[j,,,]/255,0,0,1,1)
text(0.5, 0.1, class_names[i+1], cex = 2, col = "red")
}
Before we jump into transfer learning, let’s train a simple model first and see how well it performs! Z is a 4-D array containing our images. First, split Z into training and validation sets:
z_train_id <- sample(c("train", "valid"), nrow(Z), TRUE, prob = c(0.6, 0.4))
Z_train <- Z[z_train_id == "train",,,] # Z is a 4-D array; keep all image dimensions
y_train <- to_categorical(y[z_train_id == "train"])
Fit a simple convolutional model to Z:
simple_model <- keras_model_sequential()
simple_model %>%
layer_conv_2d(filters = 16, kernel_size = c(3, 3),
input_shape = dim(Z_train)[-1],
padding = "same") %>%
layer_max_pooling_2d(pool_size = c(2, 2)) %>%
layer_activation(activation = "relu") %>%
layer_conv_2d(filters = 16, kernel_size = c(3, 3),
padding = "same") %>%
layer_max_pooling_2d(pool_size = c(2, 2)) %>%
layer_activation(activation = "relu") %>%
layer_conv_2d(filters = 16, kernel_size = c(3, 3),
padding = "same") %>%
layer_max_pooling_2d(pool_size = c(2, 2)) %>%
layer_activation(activation = "relu") %>%
layer_flatten() %>%
layer_dense(units = ncol(y_train)) %>%
layer_activation(activation = "softmax")
simple_model %>% compile(loss = 'categorical_crossentropy',
optimizer = optimizer_sgd(lr = 0.01, momentum = 0.8),
metrics = c('accuracy'))
simple_model
## Model
## ___________________________________________________________________________
## Layer (type) Output Shape Param #
## ===========================================================================
## conv2d_1 (Conv2D) (None, 224, 224, 16) 448
## ___________________________________________________________________________
## max_pooling2d_1 (MaxPooling2D) (None, 112, 112, 16) 0
## ___________________________________________________________________________
## activation_1 (Activation) (None, 112, 112, 16) 0
## ___________________________________________________________________________
## conv2d_2 (Conv2D) (None, 112, 112, 16) 2320
## ___________________________________________________________________________
## max_pooling2d_2 (MaxPooling2D) (None, 56, 56, 16) 0
## ___________________________________________________________________________
## activation_2 (Activation) (None, 56, 56, 16) 0
## ___________________________________________________________________________
## conv2d_3 (Conv2D) (None, 56, 56, 16) 2320
## ___________________________________________________________________________
## max_pooling2d_3 (MaxPooling2D) (None, 28, 28, 16) 0
## ___________________________________________________________________________
## activation_3 (Activation) (None, 28, 28, 16) 0
## ___________________________________________________________________________
## flatten_1 (Flatten) (None, 12544) 0
## ___________________________________________________________________________
## dense_1 (Dense) (None, 20) 250900
## ___________________________________________________________________________
## activation_4 (Activation) (None, 20) 0
## ===========================================================================
## Total params: 255,988
## Trainable params: 255,988
## Non-trainable params: 0
## ___________________________________________________________________________
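As a sanity check on the summary above, the parameter counts can be reproduced by hand using the standard formulas for convolutional and dense layers (plain R arithmetic, nothing model-specific):

```r
# Conv2D params = kernel_h * kernel_w * in_channels * filters + filters (biases)
conv_params <- function(kh, kw, in_ch, filters) kh * kw * in_ch * filters + filters
conv1 <- conv_params(3, 3, 3, 16)   # 448: RGB input has 3 channels
conv2 <- conv_params(3, 3, 16, 16)  # 2320
conv3 <- conv_params(3, 3, 16, 16)  # 2320
# Dense layer: 28 * 28 * 16 = 12544 flattened inputs, 20 outputs, 20 biases
dense <- 12544 * 20 + 20            # 250900
conv1 + conv2 + conv3 + dense       # 255988, matching "Total params" above
```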
Train the simple model:
history <- simple_model %>% fit(Z_train, y_train, epochs = 10,
validation_split = 0.1)
plot(history)
Let’s look at how our simple model performs:
simple_y_pred <- predict_classes(simple_model, Z)
tapply(y == simple_y_pred, z_train_id, mean)
## train valid
## 0.06636501 0.03800475
This sucks! The model barely learns anything at all; with 20 classes, random guessing alone would give about 5% accuracy. Let’s see if we can do better with transfer learning.
For the transfer learning task, we will use the ResNet50 pre-trained model, which was trained on over a million images from the ImageNet database. The network is 50 layers deep and can classify images into 1000 object categories.
We will use ResNet50 with its final classification layer excluded to embed the images into a denser representation (2048-dimensional vectors). This will make our classification task significantly easier.
Import the pre-trained model and grab its second-to-last layer ('avg_pool'):
resnet50 <- application_resnet50(weights = 'imagenet', include_top = TRUE)
model_avg_pool <- keras_model(inputs = resnet50$input,
outputs = get_layer(resnet50, 'avg_pool')$output)
Embed the images using the pre-trained model:
X_embedded <- predict(model_avg_pool, x = imagenet_preprocess_input(Z), verbose = TRUE)
dim(X_embedded)
## [1] 1084 1 1 2048
X <- drop(X_embedded)
dim(X)
## [1] 1084 2048
Now that we have embedded our images into 2048 dimensional vectors, we can perform classification taking in those vectors as input.
We split X into training and validation sets with a 60/40 split:
train_id <- sample(c("train", "valid"), nrow(X), TRUE, prob = c(0.6, 0.4))
X_train <- X[train_id == "train",] # Note: X is a matrix
y_train <- to_categorical(y[train_id == "train"])
Then we train a model on the embedded corpus:
model <- keras_model_sequential()
model %>%
layer_dense(units = 256, input_shape = ncol(X_train)) %>%
layer_activation(activation = "relu") %>%
layer_dropout(rate = 0.5) %>%
layer_dense(units = 256) %>%
layer_activation(activation = "relu") %>%
layer_dropout(rate = 0.5) %>%
layer_dense(units = ncol(y_train)) %>%
layer_activation(activation = "softmax")
model %>% compile(loss = 'categorical_crossentropy',
optimizer = optimizer_rmsprop(lr = 0.0005),
metrics = c('accuracy'))
history <- model %>%
fit(X_train, y_train, epochs = 10)
plot(history)
Here is the accuracy of the model on both train and validation sets:
y_pred <- predict_classes(model, X)
tapply(y == y_pred, train_id, mean)
## train valid
## 1.0000000 0.9689737
This works much better than the simple model! We can see which classes are more easily misclassified using the confusion matrix:
table(value = class_names[y + 1L], prediction = class_names[y_pred + 1L], train_id)
## , , train_id = train
##
## prediction
## value crab cup helicopter lobster lotus mandolin mayfly pigeon
## crab 38 0 0 0 0 0 0 0
## cup 0 36 0 0 0 0 0 0
## helicopter 0 0 51 0 0 0 0 0
## lobster 0 0 0 29 0 0 0 0
## lotus 0 0 0 0 42 0 0 0
## mandolin 0 0 0 0 0 23 0 0
## mayfly 0 0 0 0 0 0 29 0
## pigeon 0 0 0 0 0 0 0 29
## pizza 0 0 0 0 0 0 0 0
## platypus 0 0 0 0 0 0 0 0
## pyramid 0 0 0 0 0 0 0 0
## revolver 0 0 0 0 0 0 0 0
## rhino 0 0 0 0 0 0 0 0
## rooster 0 0 0 0 0 0 0 0
## saxophone 0 0 0 0 0 0 0 0
## schooner 0 0 0 0 0 0 0 0
## scissors 0 0 0 0 0 0 0 0
## windsor_chair 0 0 0 0 0 0 0 0
## wrench 0 0 0 0 0 0 0 0
## yin_yang 0 0 0 0 0 0 0 0
## prediction
## value pizza platypus pyramid revolver rhino rooster saxophone
## crab 0 0 0 0 0 0 0
## cup 0 0 0 0 0 0 0
## helicopter 0 0 0 0 0 0 0
## lobster 0 0 0 0 0 0 0
## lotus 0 0 0 0 0 0 0
## mandolin 0 0 0 0 0 0 0
## mayfly 0 0 0 0 0 0 0
## pigeon 0 0 0 0 0 0 0
## pizza 33 0 0 0 0 0 0
## platypus 0 22 0 0 0 0 0
## pyramid 0 0 38 0 0 0 0
## revolver 0 0 0 44 0 0 0
## rhino 0 0 0 0 33 0 0
## rooster 0 0 0 0 0 27 0
## saxophone 0 0 0 0 0 0 24
## schooner 0 0 0 0 0 0 0
## scissors 0 0 0 0 0 0 0
## windsor_chair 0 0 0 0 0 0 0
## wrench 0 0 0 0 0 0 0
## yin_yang 0 0 0 0 0 0 0
## prediction
## value schooner scissors windsor_chair wrench yin_yang
## crab 0 0 0 0 0
## cup 0 0 0 0 0
## helicopter 0 0 0 0 0
## lobster 0 0 0 0 0
## lotus 0 0 0 0 0
## mandolin 0 0 0 0 0
## mayfly 0 0 0 0 0
## pigeon 0 0 0 0 0
## pizza 0 0 0 0 0
## platypus 0 0 0 0 0
## pyramid 0 0 0 0 0
## revolver 0 0 0 0 0
## rhino 0 0 0 0 0
## rooster 0 0 0 0 0
## saxophone 0 0 0 0 0
## schooner 38 0 0 0 0
## scissors 0 27 0 0 0
## windsor_chair 0 0 38 0 0
## wrench 0 0 0 27 0
## yin_yang 0 0 0 0 37
##
## , , train_id = valid
##
## prediction
## value crab cup helicopter lobster lotus mandolin mayfly pigeon
## crab 34 0 0 1 0 0 0 0
## cup 0 18 0 0 0 0 0 0
## helicopter 0 0 36 0 0 0 0 0
## lobster 0 0 1 11 0 0 0 0
## lotus 0 0 0 0 24 0 0 0
## mandolin 0 0 0 0 0 19 0 0
## mayfly 0 0 0 0 0 0 11 0
## pigeon 0 0 0 0 0 0 0 16
## pizza 0 0 0 0 0 0 0 0
## platypus 1 0 0 0 0 0 0 0
## pyramid 0 0 0 0 0 0 0 0
## revolver 0 0 1 0 0 0 0 0
## rhino 0 0 0 0 0 0 0 0
## rooster 0 0 0 0 0 0 0 1
## saxophone 0 0 0 0 0 0 0 0
## schooner 0 0 0 0 0 0 0 0
## scissors 0 0 0 0 0 0 0 0
## windsor_chair 0 0 0 0 0 0 0 0
## wrench 0 0 0 1 0 0 0 0
## yin_yang 0 0 0 0 0 0 0 0
## prediction
## value pizza platypus pyramid revolver rhino rooster saxophone
## crab 0 0 0 0 0 0 0
## cup 0 0 0 0 0 1 0
## helicopter 0 0 0 1 0 0 0
## lobster 0 0 0 0 0 0 0
## lotus 0 0 0 0 0 0 0
## mandolin 0 0 0 0 0 0 0
## mayfly 0 0 0 0 0 0 0
## pigeon 0 0 0 0 0 0 0
## pizza 20 0 0 0 0 0 0
## platypus 0 11 0 0 0 0 0
## pyramid 0 0 19 0 0 0 0
## revolver 0 0 0 36 0 0 0
## rhino 0 0 0 0 26 0 0
## rooster 0 0 0 0 0 21 0
## saxophone 0 0 0 0 0 0 16
## schooner 0 0 0 0 0 0 0
## scissors 0 0 0 0 0 0 0
## windsor_chair 0 0 0 0 0 0 0
## wrench 0 0 0 0 0 0 0
## yin_yang 0 0 0 0 0 0 0
## prediction
## value schooner scissors windsor_chair wrench yin_yang
## crab 0 0 0 0 0
## cup 0 2 0 0 0
## helicopter 0 0 0 0 0
## lobster 0 0 0 0 0
## lotus 0 0 0 0 0
## mandolin 0 0 0 1 0
## mayfly 0 0 0 0 0
## pigeon 0 0 0 0 0
## pizza 0 0 0 0 0
## platypus 0 0 0 0 0
## pyramid 0 0 0 0 0
## revolver 0 0 0 1 0
## rhino 0 0 0 0 0
## rooster 0 0 0 0 0
## saxophone 0 0 0 0 0
## schooner 25 0 0 0 0
## scissors 0 12 0 0 0
## windsor_chair 0 0 18 0 0
## wrench 0 1 0 10 0
## yin_yang 0 0 0 0 23
Here are samples of the wrong predictions:
par(mfrow = c(4, 4))
id <- which(y_pred != y)
for (i in id) {
par(mar = rep(0, 4L))
plot(0,0,xlim=c(0,1),ylim=c(0,1),axes= FALSE,type = "n")
rasterImage(Z[i,,,] /255,0,0,1,1)
text(0.5, 0.1, label = class_names[y[i] + 1L], col = "green", cex=2)
text(0.5, 0.3, label = class_names[y_pred[i] + 1L], col = "red", cex=2)
}
Here are the most representative images for each class (highest predicted probability) according to our model:
y_probs <- predict(model, X)
par(mfrow = c(4, 5))
for (i in 0:19) {
par(mar = rep(0, 4L))
plot(0,0,xlim=c(0,1),ylim=c(0,1),axes= FALSE,type = "n")
j <- order(y_probs[,i+1], decreasing = TRUE)[1]
rasterImage(Z[j,,,]/255,0,0,1,1)
text(0.5, 0.1, class_names[i+1], cex = 2, col = "salmon")
}
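Since the embeddings are just a numeric matrix, a linear classifier is another option; glmnet was loaded at the start, so a penalized multinomial logistic regression makes a natural baseline. A sketch, assuming X, y, and train_id from above (an illustration, not part of the original analysis):

```r
# Ridge-penalized multinomial logistic regression on the ResNet50 embeddings.
# cv.glmnet chooses the penalty strength by cross-validation on the training rows.
cv_fit <- cv.glmnet(X[train_id == "train", ], y[train_id == "train"],
                    family = "multinomial", alpha = 0)
# Predicted class labels for every image, at the CV-selected lambda
glmnet_pred <- as.numeric(predict(cv_fit, X, s = "lambda.min", type = "class"))
# Accuracy on train and valid splits, same layout as the keras results above
tapply(y == glmnet_pred, train_id, mean)
```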
The results show that the transfer learning model was very effective. Many of the remaining misclassified images could plausibly be misclassified by a human as well.
In conclusion, transfer learning has allowed us to build a customized model with very little data.